
from cbml_benchmark.modeling.registry import HEADS

from .linear_norm import LinearNorm,LinearNormMap


def build_linear_head(cfg):
    """Instantiate the head registered under ``cfg.MODEL.HEAD.NAME0``.

    The class is looked up in ``HEADS`` and called with ``cfg`` plus
    ``in_channels`` taken from ``cfg.MODEL.HEAD.trans_embdding_dim``.

    Args:
        cfg: Config object exposing ``MODEL.HEAD.NAME0`` and
            ``MODEL.HEAD.trans_embdding_dim``.

    Returns:
        Whatever the registered head callable returns.

    Raises:
        AssertionError: If the configured name is not in ``HEADS``
            (only while assertions are enabled).
    """
    head_cfg = cfg.MODEL.HEAD
    head_name = head_cfg.NAME0
    assert head_name in HEADS, f"head {head_name} is not defined"
    head_factory = HEADS[head_name]
    return head_factory(cfg, in_channels=head_cfg.trans_embdding_dim)

def build_ra_head2(cfg):
    """Instantiate the head registered under ``cfg.MODEL.HEAD.NAME3`` with ``num_mat=1``.

    The class is looked up in ``HEADS`` and called with keyword arguments
    only: ``in_channels`` is element 3 of ``cfg.MODEL.HEAD.trans_in_channels``
    and ``embedding_dim`` is ``cfg.MODEL.HEAD.trans_embdding_dim``.

    Args:
        cfg: Config object exposing ``MODEL.HEAD.NAME3``,
            ``MODEL.HEAD.trans_in_channels`` and
            ``MODEL.HEAD.trans_embdding_dim``.

    Returns:
        Whatever the registered head callable returns.

    Raises:
        AssertionError: If the configured name is not in ``HEADS``
            (only while assertions are enabled).
    """
    # NOTE(review): this reads NAME3 although the function is named "head2" --
    # confirm that sharing the key with build_ra_head3 is intentional.
    head_cfg = cfg.MODEL.HEAD
    head_name = head_cfg.NAME3
    assert head_name in HEADS, f"head {head_name} is not defined"
    head_factory = HEADS[head_name]
    return head_factory(
        in_channels=head_cfg.trans_in_channels[3],
        embedding_dim=head_cfg.trans_embdding_dim,
        num_mat=1,
    )

def build_ra_head3(cfg):
    """Instantiate the head registered under ``cfg.MODEL.HEAD.NAME3`` with ``num_mat=64``.

    The class is looked up in ``HEADS`` and called with keyword arguments
    only: ``in_channels`` is element 3 of ``cfg.MODEL.HEAD.trans_in_channels``
    and ``embedding_dim`` is ``cfg.MODEL.HEAD.trans_embdding_dim``.

    Args:
        cfg: Config object exposing ``MODEL.HEAD.NAME3``,
            ``MODEL.HEAD.trans_in_channels`` and
            ``MODEL.HEAD.trans_embdding_dim``.

    Returns:
        Whatever the registered head callable returns.

    Raises:
        AssertionError: If the configured name is not in ``HEADS``
            (only while assertions are enabled).
    """
    head_cfg = cfg.MODEL.HEAD
    head_name = head_cfg.NAME3
    assert head_name in HEADS, f"head {head_name} is not defined"
    head_factory = HEADS[head_name]
    return head_factory(
        in_channels=head_cfg.trans_in_channels[3],
        embedding_dim=head_cfg.trans_embdding_dim,
        num_mat=64,
    )

